Improve PT/TF equivalence test #16557
Conversation
tests/clip/test_modeling_tf_clip.py
No need for this anymore: the test in TF common can handle nested outputs, including instances of ModelOutput.
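For intuition, here is a minimal sketch of how a common test can recurse through nested outputs, including ModelOutput-style objects that expose `to_tuple()`. This is not the actual transformers code; the function name, the naming scheme, and the tolerance are illustrative.

```python
def check_outputs(tf_out, pt_out, name="output", tol=1e-5):
    """Recursively compare two (possibly nested) output structures."""
    # ModelOutput-like objects expose .to_tuple(); unwrap them first.
    if hasattr(tf_out, "to_tuple"):
        tf_out, pt_out = tf_out.to_tuple(), pt_out.to_tuple()
    if isinstance(tf_out, (list, tuple)):
        assert len(tf_out) == len(pt_out), f"{name}: length mismatch"
        for i, (t, p) in enumerate(zip(tf_out, pt_out)):
            # Append the index so a failure reports e.g. "output.attentions_1".
            check_outputs(t, p, name=f"{name}_{i}", tol=tol)
    elif isinstance(tf_out, dict):
        assert tf_out.keys() == pt_out.keys(), f"{name}: key mismatch"
        for key in tf_out:
            check_outputs(tf_out[key], pt_out[key], name=f"{name}.{key}", tol=tol)
    else:
        # Leaves: compare numeric values within tolerance.
        assert abs(tf_out - pt_out) <= tol, f"{name}: {tf_out} vs {pt_out}"
```

The recursion means model-specific tests only need to produce equivalent inputs; any shape of output (tuples, dicts, nested ModelOutput) is handled uniformly.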
tests/led/test_modeling_tf_led.py
This was done before to give TF-LED a strong test while the common version was still a loose test.
Now that the common test is (very) strong, we no longer need this test in TF-LED.
Can I add `import torch` here without `is_torch_available` or `require_torch`? This method will only be called inside `test_pt_tf_model_equivalence`, which is already decorated with `is_pt_tf_cross_test`.
That's just a marker that reads an env variable, so I think it should have the require_torch just in case, but I'm not sure if we are very consistent with that. @LysandreJik might know better.
I don't think it really matters, as it is indeed already decorated with `is_pt_tf_cross_test`. We don't have a convention set, so feel free to choose the simplest approach.
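For illustration, such a marker can be sketched as a thin wrapper over `unittest.skipUnless` that reads an environment variable; the names here are a stand-in, not the actual `transformers.testing_utils` implementation.

```python
import os
import unittest

# Illustrative stand-in for the is_pt_tf_cross_test marker: it only decides
# whether to skip, so heavy imports like torch can live inside the test body.
def is_pt_tf_cross_test(test_case):
    return unittest.skipUnless(
        os.environ.get("RUN_PT_TF_CROSS_TESTS", "0") == "1",
        "PT/TF cross tests are disabled (set RUN_PT_TF_CROSS_TESTS=1)",
    )(test_case)
```

Because the marker skips before the test body runs, a local `import torch` inside the decorated method is never reached when torch-less CI skips the test.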
This is the specific part for the LXMERT test.
(It is possible to move this part to the common PT/TF test method, but I think it's fine/better to overwrite it here.)
Removed. The new version uses `elif tf_inputs_dict[key].dtype.is_floating:`, which I find cleaner and more general.
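As a rough sketch of what that branch achieves (numpy stands in for TF tensors here; the function name and dtype choices are illustrative): floating inputs stay float, everything else (ids, masks) is cast to int64 for the PyTorch side.

```python
import numpy as np

def prepare_pt_style_inputs(tf_inputs_dict):
    """Illustrative conversion: floats stay floats, the rest become int64."""
    pt_inputs_dict = {}
    for key, value in tf_inputs_dict.items():
        if np.issubdtype(value.dtype, np.floating):  # mirrors tf dtype.is_floating
            pt_inputs_dict[key] = value.astype(np.float32)
        else:
            pt_inputs_dict[key] = value.astype(np.int64)
    return pt_inputs_dict
```

Keying on "is it floating?" rather than enumerating integer input names is what makes the check general across models.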
In the new version, this is handled in `prepare_pt_inputs_from_tf_inputs`:

```python
if isinstance(value, dict):
    pt_inputs_dict[key] = self.prepare_pt_inputs_from_tf_inputs(value)
elif isinstance(value, (list, tuple)):
    pt_inputs_dict[key] = (self.prepare_pt_inputs_from_tf_inputs(iter_value) for iter_value in value)
```

In the new version, we only need to overwrite `prepare_pt_inputs_from_tf_inputs`, because that is the only place that actually differs from the common version.
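The pattern can be sketched like this (class names and the key tweak are hypothetical, purely to show the shape of the override):

```python
# A model-specific tester overrides only the input-preparation hook; the rest
# of the equivalence test stays in the common mixin.

class CommonTesterMixin:
    def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
        # Generic conversion, shown here as a shallow copy.
        return dict(tf_inputs_dict)

class LxmertStyleTester(CommonTesterMixin):
    def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
        pt_inputs_dict = super().prepare_pt_inputs_from_tf_inputs(tf_inputs_dict)
        # Hypothetical model-specific difference: the PT model expects a
        # differently named argument.
        if "visual_feats" in pt_inputs_dict:
            pt_inputs_dict["visual_features"] = pt_inputs_dict.pop("visual_feats")
        return pt_inputs_dict
```

Everything downstream (running both models, comparing outputs) is inherited unchanged, which is why the full `test_pt_tf_model_equivalence` copy can be deleted.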
I prefer to call `super()` here, because the only difference is the added `noise` argument in the block above.
We just need to overwrite check_pt_tf_models.
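A sketch of that override (shapes, names, and the stubbed common check are assumptions for illustration): fix the random masking noise so both frameworks see identical inputs, then defer to `super()`.

```python
import numpy as np

class CommonTesterMixin:
    def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
        # The common logic would run both models and compare outputs;
        # stubbed out here to keep the sketch self-contained.
        self.last_inputs = tf_inputs_dict

class ViTMAEStyleTester(CommonTesterMixin):
    def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
        batch_size, num_patches = 1, 196  # hypothetical shape
        # Deterministic noise shared by both models removes the only source
        # of PT/TF divergence coming from ViTMAE's random masking.
        noise = np.random.uniform(size=(batch_size, num_patches))
        tf_inputs_dict = dict(tf_inputs_dict, noise=noise)
        super().check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
```

The override is a few lines instead of a full reimplementation of the equivalence test.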
tests/test_modeling_tf_common.py
Not sure if we should test this argument; I think it is not worth it.
Not sure why it was added, but it doesn't look useful, I agree.
It was added by me during the process: sometimes I passed the wrong arguments and got errors.
However, those arguments are unlikely to be used by anyone else (unless someone wants to change `check_pt_tf_outputs`).
tests/test_modeling_tf_common.py
Make the failure message more informative by adding the corresponding tensor name, like `output.hidden_states`.
sgugger
left a comment
Thanks for cleaning those. It's great we can remove some model-specific code to rely on the generic common tests!
gante
left a comment
This is great, it makes writing tests for edge cases much easier 🚀
(just rebase on main - no real change since your last review)

Merging now. Don't hesitate to leave comments in any case :-)
* add error message
* Use names in the error message
* allow ModelOutput
* rename to check_pt_tf_outputs and move outside
* fix style
* skip past_key_values in a better way
* Add comments
* improve code for label/loss
* make the logic clear by moving the ignore keys out
* fix _postprocessing_to_ignore
* fix _postprocessing_to_ignore: create new outputs from the remaining fields
* ignore past_key_values in TFGPT2 models for now
* make check_pt_tf_outputs better regarding names
* move check_pt_tf_models outside
* rename methods
* remove test_pt_tf_model_equivalence in TFCLIPModelTest
* Reduce TFViTMAEModelTest.test_pt_tf_model_equivalence
* move prepare_pt_inputs_from_tf_inputs outside check_pt_tf_models
* Fix quality
* Clean-up TFLxmertModelTester.test_pt_tf_model_equivalence
* Fix quality
* fix
* fix style
* Clean-up TFLEDModelTest.test_pt_tf_model_equivalence
* Fix quality
* add docstring
* improve comment

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
What does this PR do?
Improve PT/TF equivalence test.
To make the review a bit easier for you, I made some comments. Here is a summary of the changes:
* `test_pt_tf_model_equivalence` in TensorFlow `LED` and `CLIP` are removed: the common one can handle them.
* `test_pt_tf_model_equivalence` in TensorFlow `LXMERT` and `ViTMAE` are removed: we only need to overwrite
  * `prepare_pt_inputs_from_tf_inputs` for `LXMERT`
  * `check_pt_tf_models` for `ViTMAE`
* `TFModelTesterMixin.test_pt_tf_model_equivalence`:
  * `_make_attention_mask_non_null`
  * `_postprocessing_to_ignore_test_cases`
  * `check_pt_tf_outputs`:
    * `ModelOutput` (for CLIP model)
    * `output.hidden_states` or `output.text_model_output.attentions_1`

Once this PR is approved/merged:
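One item above, skipping fields like past_key_values before the comparison, can be sketched as building a new output mapping from the remaining fields (the function name and default key are illustrative, not the exact `_postprocessing_to_ignore` signature):

```python
def postprocess_outputs(outputs, keys_to_ignore=("past_key_values",)):
    """Drop output fields that are not expected to match across frameworks."""
    return {key: value for key, value in outputs.items() if key not in keys_to_ignore}
```

Creating a new mapping, rather than mutating the original outputs, keeps the model outputs intact for any later checks.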